/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    SGD.java
 *    Copyright (C) 2009-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.functions;

import java.util.ArrayList;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

import weka.classifiers.RandomizableClassifier;
import weka.classifiers.UpdateableClassifier;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Aggregateable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

/**
 <!-- globalinfo-start --> 
 * Implements stochastic gradient descent for learning
 * various linear models (binary class SVM, binary class logistic regression,
 * squared loss, Huber loss and epsilon-insensitive loss linear regression).
 * Globally replaces all missing values and transforms nominal attributes into
 * binary ones. It also normalizes all attributes, so the coefficients in the
 * output are based on the normalized data.<br/>
 * For numeric class attributes, the squared, Huber or epsilon-insensitive loss
 * function must be used. Epsilon-insensitive and Huber loss may require a much
 * higher learning rate.
 * <p/>
 <!-- globalinfo-end -->
 * 
 <!-- options-start --> 
 * Valid options are:
 * <p/>
 * 
 * <pre>
 * -F
 *  Set the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression),
 *  2 = squared loss (regression), 3 = epsilon insensitive loss (regression),
 *  4 = Huber loss (regression).
 *  (default = 0)
 * </pre>
 * 
 * <pre>
 * -L
 *  The learning rate. If normalization is
 *  turned off (as it is automatically for streaming data), then the
 *  default learning rate will need to be reduced (try 0.0001).
 *  (default = 0.01).
 * </pre>
 * 
 * <pre>
 * -R &lt;double&gt;
 *  The lambda regularization constant (default = 0.0001)
 * </pre>
 * 
 * <pre>
 * -E &lt;integer&gt;
 *  The number of epochs to perform (batch learning only, default = 500)
 * </pre>
 * 
 * <pre>
 * -C &lt;double&gt;
 *  The epsilon threshold (epsilon-insensitive and Huber loss only, default = 1e-3)
 * </pre>
 * 
 * <pre>
 * -N
 *  Don't normalize the data
 * </pre>
 * 
 * <pre>
 * -M
 *  Don't replace missing values
 * </pre>
 * 
 <!-- options-end -->
 * 
 * @author Eibe Frank (eibe{[at]}cs{[dot]}waikato{[dot]}ac{[dot]}nz)
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 * @version $Revision: 9785 $
 * 
 */
public class SGD extends RandomizableClassifier implements
    UpdateableClassifier, OptionHandler, Aggregateable<SGD> {

  /** For serialization */
  private static final long serialVersionUID = -3732968666673530290L;

  /** Filter that replaces missing values (set up in batch learning only) */
  protected ReplaceMissingValues m_replaceMissing;

  /**
   * Convert nominal attributes to numerically coded binary ones. Uses
   * supervised NominalToBinary in the batch learning case
   */
  protected Filter m_nominalToBinary;

  /** Filter that normalizes the training data (set up in batch learning only) */
  protected Normalize m_normalize;

  /** The regularization parameter (weight decay strength) */
  protected double m_lambda = 0.0001;

  /** The learning rate */
  protected double m_learningRate = 0.01;

  /**
   * Stores the weights, indexed by attribute index (+ bias in the last
   * element; the slot at the class index is unused)
   */
  protected double[] m_weights;

  /** The epsilon parameter for epsilon insensitive and Huber loss */
  protected double m_epsilon = 1e-3;

  /** Holds the current iteration number (reset() restarts it at 1) */
  protected double m_t;

  /**
   * The number of training instances seen by buildClassifier(); stays 0 for
   * purely incremental training (used as the weight-decay denominator)
   */
  protected double m_numInstances;

  /**
   * The number of epochs to perform (batch learning). Total iterations is
   * m_epochs * num instances
   */
  protected int m_epochs = 500;

  /**
   * Turn off normalization of the input data. This option gets forced for
   * incremental training.
   */
  protected boolean m_dontNormalize = false;

  /**
   * Turn off global replacement of missing values. Missing values will be
   * ignored instead. This option gets forced for incremental training.
   */
  protected boolean m_dontReplaceMissing = false;

  /** Holds the header (zero instances) of the training data */
  protected Instances m_data;

  /**
   * Returns default capabilities of the classifier.
   * 
   * @return the capabilities of this classifier
   */
  @Override
  public Capabilities getCapabilities() {
    Capabilities caps = super.getCapabilities();
    caps.disableAll();

    // attributes
    caps.enable(Capability.NOMINAL_ATTRIBUTES);
    caps.enable(Capability.NUMERIC_ATTRIBUTES);
    caps.enable(Capability.MISSING_VALUES);

    // class: the regression losses admit a numeric class, otherwise the
    // model is a binary classifier
    boolean regressionLoss = m_loss == SQUAREDLOSS
        || m_loss == EPSILON_INSENSITIVE || m_loss == HUBER;
    if (regressionLoss) {
      caps.enable(Capability.NUMERIC_CLASS);
    } else {
      caps.enable(Capability.BINARY_CLASS);
    }
    caps.enable(Capability.MISSING_CLASS_VALUES);

    // instances
    caps.setMinimumNumberInstances(0);

    return caps;
  }

  /**
   * Returns the tip text for this property
   * 
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String epsilonTipText() {
    // Fixed typo in the displayed text: "less that" -> "less than".
    return "The epsilon threshold for epsilon insensitive and Huber "
        + "loss. An error with absolute value less than this "
        + "threshold has loss of 0 for epsilon insensitive loss. "
        + "For Huber loss this is the boundary between the quadratic "
        + "and linear parts of the loss function.";
  }

  /**
   * Set the epsilon threshold on the error for epsilon insensitive and Huber
   * loss functions
   * 
   * @param e the value of epsilon to use
   */
  public void setEpsilon(double e) {
    m_epsilon = e;
  }

  /**
   * Get the epsilon threshold on the error for epsilon insensitive and Huber
   * loss functions
   * 
   * @return the value of epsilon to use
   */
  public double getEpsilon() {
    return m_epsilon;
  }

  /**
   * Returns the tip text for this property
   * 
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String lambdaTipText() {
    return "The regularization constant. (default = 0.0001)";
  }

  /**
   * Sets the regularization constant lambda.
   * 
   * @param lambda the value of lambda to use
   */
  public void setLambda(double lambda) {
    this.m_lambda = lambda;
  }

  /**
   * Returns the regularization constant lambda currently in use.
   * 
   * @return the current value of lambda
   */
  public double getLambda() {
    return this.m_lambda;
  }

  /**
   * Set the learning rate.
   * 
   * @param lr the learning rate to use.
   */
  public void setLearningRate(double lr) {
    m_learningRate = lr;
  }

  /**
   * Get the learning rate.
   * 
   * @return the learning rate
   */
  public double getLearningRate() {
    return m_learningRate;
  }

  /**
   * Returns the tip text for this property
   * 
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String learningRateTipText() {
    // Fixed missing space: the concatenation previously produced "thenthe".
    return "The learning rate. If normalization is turned off "
        + "(as it is automatically for streaming data), then "
        + "the default learning rate will need to be reduced ("
        + "try 0.0001).";
  }

  /**
   * Returns the tip text for this property
   * 
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String epochsTipText() {
    return "The number of epochs to perform (batch learning). "
        + "The total number of iterations is epochs * num instances.";
  }

  /**
   * Sets how many passes over the training data batch learning performs.
   * 
   * @param e the number of epochs to use
   */
  public void setEpochs(int e) {
    this.m_epochs = e;
  }

  /**
   * Returns how many passes over the training data batch learning performs.
   * 
   * @return the current number of epochs
   */
  public int getEpochs() {
    return this.m_epochs;
  }

  /**
   * Toggles normalization of the input data.
   * 
   * @param m true if normalization is to be disabled.
   */
  public void setDontNormalize(boolean m) {
    this.m_dontNormalize = m;
  }

  /**
   * Reports whether normalization has been switched off.
   * 
   * @return true if normalization has been disabled.
   */
  public boolean getDontNormalize() {
    return this.m_dontNormalize;
  }

  /**
   * Returns the tip text for this property
   * 
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String dontNormalizeTipText() {
    return "Turn normalization off";
  }

  /**
   * Toggles global replacement of missing values. When turned off, missing
   * values are effectively ignored.
   * 
   * @param m true if global replacement of missing values is to be turned off.
   */
  public void setDontReplaceMissing(boolean m) {
    this.m_dontReplaceMissing = m;
  }

  /**
   * Reports whether global replacement of missing values is switched off.
   * 
   * @return true if global replacement of missing values has been turned off
   */
  public boolean getDontReplaceMissing() {
    return this.m_dontReplaceMissing;
  }

  /**
   * Returns the tip text for this property
   * 
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String dontReplaceMissingTipText() {
    return "Turn off global replacement of missing values";
  }

  /**
   * Set the loss function to use.
   * 
   * @param function the loss function to use.
   */
  public void setLossFunction(SelectedTag function) {
    if (function.getTags() == TAGS_SELECTION) {
      m_loss = function.getSelectedTag().getID();
    }
  }

  /**
   * Get the current loss function.
   * 
   * @return the current loss function.
   */
  public SelectedTag getLossFunction() {
    return new SelectedTag(m_loss, TAGS_SELECTION);
  }

  /**
   * Returns the tip text for this property
   * 
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String lossFunctionTipText() {
    // Updated to list all five losses offered by TAGS_SELECTION (the old
    // text omitted the epsilon-insensitive and Huber options).
    return "The loss function to use. Hinge loss (SVM), "
        + "log loss (logistic regression), squared loss (regression), "
        + "epsilon-insensitive loss (SVM regression) or "
        + "Huber loss (robust regression).";
  }

  /**
   * Returns an enumeration describing the available options.
   * 
   * @return an enumeration of all the available options.
   */
  @Override
  public Enumeration<Option> listOptions() {

    Vector<Option> newVector = new Vector<Option>();
    newVector.add(new Option("\tSet the loss function to minimize.\n\t0 = "
        + "hinge loss (SVM), 1 = log loss (logistic regression),\n\t"
        + "2 = squared loss (regression), 3 = epsilon insensitive loss (regression)," +
        "\n\t4 = Huber loss (regression).\n\t(default = 0)", "F", 1, "-F"));
    newVector
        .add(new Option(
            "\tThe learning rate. If normalization is\n"
                + "\tturned off (as it is automatically for streaming data), then the\n\t"
                + "default learning rate will need to be reduced "
                + "(try 0.0001).\n\t(default = 0.01).", "L", 1, "-L"));
    newVector.add(new Option("\tThe lambda regularization constant "
        + "(default = 0.0001)", "R", 1, "-R <double>"));
    newVector.add(new Option("\tThe number of epochs to perform ("
        + "batch learning only, default = 500)", "E", 1, "-E <integer>"));
    // Fixed typo in the option description: "insenstive" -> "insensitive".
    newVector.add(new Option("\tThe epsilon threshold ("
        + "epsilon-insensitive and Huber loss only, default = 1e-3)", "C", 1,
        "-C <double>"));
    newVector.add(new Option("\tDon't normalize the data", "N", 0, "-N"));
    newVector.add(new Option("\tDon't replace missing values", "M", 0, "-M"));

    return newVector.elements();
  }

  /**
   * 
   * Parses a given list of options.
   * <p/>
   * 
   <!-- options-start --> 
   * Valid options are:
   * <p/>
   * 
   * <pre>
   * -F
   *  Set the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression),
   *  2 = squared loss (regression), 3 = epsilon insensitive loss (regression),
   *  4 = Huber loss (regression).
   *  (default = 0)
   * </pre>
   * 
   * <pre>
   * -L
   *  The learning rate. If normalization is
   *  turned off (as it is automatically for streaming data), then the
   *  default learning rate will need to be reduced (try 0.0001).
   *  (default = 0.01).
   * </pre>
   * 
   * <pre>
   * -R &lt;double&gt;
   *  The lambda regularization constant (default = 0.0001)
   * </pre>
   * 
   * <pre>
   * -E &lt;integer&gt;
   *  The number of epochs to perform (batch learning only, default = 500)
   * </pre>
   * 
   * <pre>
   * -C &lt;double&gt;
   *  The epsilon threshold (epsilon-insensitive and Huber loss only, default = 1e-3)
   * </pre>
   * 
   * <pre>
   * -N
   *  Don't normalize the data
   * </pre>
   * 
   * <pre>
   * -M
   *  Don't replace missing values
   * </pre>
   * 
   <!-- options-end -->
   * 
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  @Override
  public void setOptions(String[] options) throws Exception {
    reset();

    super.setOptions(options);

    String lossSpec = Utils.getOption('F', options);
    if (!lossSpec.isEmpty()) {
      setLossFunction(new SelectedTag(Integer.parseInt(lossSpec),
          TAGS_SELECTION));
    }

    String lambdaSpec = Utils.getOption('R', options);
    if (!lambdaSpec.isEmpty()) {
      setLambda(Double.parseDouble(lambdaSpec));
    }

    String rateSpec = Utils.getOption('L', options);
    if (!rateSpec.isEmpty()) {
      setLearningRate(Double.parseDouble(rateSpec));
    }

    String epochsSpec = Utils.getOption("E", options);
    if (!epochsSpec.isEmpty()) {
      setEpochs(Integer.parseInt(epochsSpec));
    }

    String epsilonSpec = Utils.getOption("C", options);
    if (!epsilonSpec.isEmpty()) {
      setEpsilon(Double.parseDouble(epsilonSpec));
    }

    setDontNormalize(Utils.getFlag("N", options));
    setDontReplaceMissing(Utils.getFlag('M', options));
  }

  /**
   * Gets the current settings of the classifier.
   * 
   * @return an array of strings suitable for passing to setOptions
   */
  @Override
  public String[] getOptions() {
    ArrayList<String> options = new ArrayList<String>();

    options.add("-F");
    options.add("" + getLossFunction().getSelectedTag().getID());
    options.add("-L");
    options.add("" + getLearningRate());
    options.add("-R");
    options.add("" + getLambda());
    options.add("-E");
    options.add("" + getEpochs());
    options.add("-C");
    options.add("" + getEpsilon());
    if (getDontNormalize()) {
      options.add("-N");
    }
    if (getDontReplaceMissing()) {
      options.add("-M");
    }

    // Include the superclass options (e.g. the -S seed) so a
    // getOptions()/setOptions() round-trip does not silently drop them;
    // setOptions() already parses super options via super.setOptions().
    for (String option : super.getOptions()) {
      options.add(option);
    }

    return options.toArray(new String[0]);
  }

  /**
   * Returns a string describing classifier
   * 
   * @return a description suitable for displaying in the explorer/experimenter
   *         gui
   */
  public String globalInfo() {
    // Fixed typo in the description: "insensitve" -> "insensitive".
    return "Implements stochastic gradient descent for learning"
        + " various linear models (binary class SVM, binary class"
        + " logistic regression, squared loss, Huber loss and "
        + "epsilon-insensitive loss linear regression)."
        + " Globally replaces all missing values and transforms nominal"
        + " attributes into binary ones. It also normalizes all attributes,"
        + " so the coefficients in the output are based on the normalized"
        + " data.\n" + "For numeric class attributes, the squared, Huber or "
        + "epsilon-insensitive loss function must be used. Epsilon-insensitive "
        + "and Huber loss may require a much higher learning rate.";
  }

  /**
   * Resets the learned model: discards the weight vector and restarts the
   * iteration counter at 1.
   */
  public void reset() {
    m_weights = null;
    m_t = 1;
  }

  /**
   * Method for building the classifier. Sets up the filter chain
   * (missing-value replacement, nominal-to-binary, normalization), then
   * trains the weight vector by stochastic gradient descent.
   * 
   * @param data the set of training instances.
   * @throws Exception if the classifier can't be built successfully.
   */
  @Override
  public void buildClassifier(Instances data) throws Exception {
    reset();

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // Work on a copy; instances without a class value contribute nothing.
    data = new Instances(data);
    data.deleteWithMissingClass();

    boolean haveData = data.numInstances() > 0;

    if (haveData && !m_dontReplaceMissing) {
      m_replaceMissing = new ReplaceMissingValues();
      m_replaceMissing.setInputFormat(data);
      data = Filter.useFilter(data, m_replaceMissing);
    }

    // Nominal predictors must be binarized before linear weights can be
    // learned over them.
    boolean needsBinarization = false;
    for (int i = 0; i < data.numAttributes() && !needsBinarization; i++) {
      if (i != data.classIndex() && !data.attribute(i).isNumeric()) {
        needsBinarization = true;
      }
    }

    if (needsBinarization) {
      // The supervised filter needs training data; with only a header
      // available, fall back to the unsupervised version.
      if (haveData) {
        m_nominalToBinary = new weka.filters.supervised.attribute.NominalToBinary();
      } else {
        m_nominalToBinary = new weka.filters.unsupervised.attribute.NominalToBinary();
      }
      m_nominalToBinary.setInputFormat(data);
      data = Filter.useFilter(data, m_nominalToBinary);
    }

    if (!m_dontNormalize && haveData) {
      m_normalize = new Normalize();
      m_normalize.setInputFormat(data);
      data = Filter.useFilter(data, m_normalize);
    }

    m_numInstances = data.numInstances();

    // One weight per attribute plus the bias term in the final slot.
    m_weights = new double[data.numAttributes() + 1];
    m_data = new Instances(data, 0);

    if (haveData) {
      data.randomize(new Random(getSeed())); // randomize the data
      train(data);
    }
  }

  /** the hinge loss function (SVM classification). */
  public static final int HINGE = 0;

  /** the log loss function (logistic regression). */
  public static final int LOGLOSS = 1;

  /** the squared loss function (linear regression). */
  public static final int SQUAREDLOSS = 2;

  /** The epsilon insensitive loss function (SVM regression) */
  public static final int EPSILON_INSENSITIVE = 3;

  /** The Huber loss function (robust regression) */
  public static final int HUBER = 4;

  /** The current loss function to minimize (one of the constants above) */
  protected int m_loss = HINGE;

  /** Loss functions to choose from */
  public static final Tag[] TAGS_SELECTION = {
      new Tag(HINGE, "Hinge loss (SVM)"),
      new Tag(LOGLOSS, "Log loss (logistic regression)"),
      new Tag(SQUAREDLOSS, "Squared loss (regression)"),
      new Tag(EPSILON_INSENSITIVE, "Epsilon-insensitive loss (SVM regression)"),
      new Tag(HUBER, "Huber loss (robust regression)") };

  /**
   * The (sub)gradient magnitude of the selected loss at margin/residual z,
   * as consumed by the weight-update rule.
   * 
   * @param z the margin (classification) or residual (regression)
   * @return the loss derivative at z
   */
  protected double dloss(double z) {
    switch (m_loss) {
    case HINGE:
      // Subgradient of max(0, 1 - z).
      return (z < 1) ? 1 : 0;

    case LOGLOSS:
      // Both branches compute 1 / (1 + e^z); the split avoids overflow in
      // Math.exp for large |z|.
      if (z < 0) {
        return 1.0 / (Math.exp(z) + 1.0);
      }
      double t = Math.exp(-z);
      return t / (t + 1);

    case EPSILON_INSENSITIVE:
      // Zero inside the epsilon tube, +/-1 outside it.
      if (z > m_epsilon) {
        return 1.0;
      }
      if (-z > m_epsilon) {
        return -1.0;
      }
      return 0;

    case HUBER:
      // Linear in the quadratic region, clipped to +/-epsilon outside it.
      if (Math.abs(z) <= m_epsilon) {
        return z;
      }
      return (z > 0.0) ? m_epsilon : -m_epsilon;

    default:
      // squared loss
      return z;
    }
  }

  /**
   * Runs m_epochs sequential passes of stochastic updates over the already
   * filtered training data.
   * 
   * @param data the filtered training instances
   * @throws Exception if an update fails
   */
  private void train(Instances data) throws Exception {
    final int numInsts = data.numInstances();
    for (int epoch = 0; epoch < m_epochs; epoch++) {
      for (int j = 0; j < numInsts; j++) {
        updateClassifier(data.instance(j), false);
      }
    }
  }

  /**
   * Dot product between a (possibly sparse) instance and a dense weight
   * vector, skipping the class attribute and missing values. The bias term
   * (last weight element) is NOT included.
   * 
   * @param inst1 the instance
   * @param weights the weight vector, bias in the last element
   * @param classIndex index of the class attribute to skip
   * @return the dot product
   */
  protected static double dotProd(Instance inst1, double[] weights,
      int classIndex) {
    final int numVals = inst1.numValues();
    final int numWeights = weights.length - 1; // exclude the bias slot

    double sum = 0;
    int v = 0; // cursor over the instance's stored (sparse) values
    int w = 0; // cursor over the dense weight vector
    while (v < numVals && w < numWeights) {
      int attIndex = inst1.index(v);
      if (attIndex == w) {
        if (attIndex != classIndex && !inst1.isMissingSparse(v)) {
          sum += inst1.valueSparse(v) * weights[w];
        }
        v++;
        w++;
      } else if (attIndex > w) {
        w++;
      } else {
        v++;
      }
    }
    return sum;
  }

  /**
   * Updates the classifier with the given instance: applies L2 weight decay,
   * then (if the loss gradient is non-zero) a gradient step on the weights
   * and bias.
   * 
   * @param instance the new training instance to include in the model
   * @param filter true if the instance should pass through any of the filters
   *          set up in buildClassifier(). When batch training buildClassifier()
   *          already batch filters all training instances so don't need to
   *          filter them again here.
   * @exception Exception if the instance could not be incorporated in the
   *              model.
   */
  protected void updateClassifier(Instance instance, boolean filter)
      throws Exception {

    // Instances with a missing class carry no training signal and are skipped.
    if (!instance.classIsMissing()) {
      if (filter) {
        // Apply the same filter chain set up by buildClassifier(); each
        // filter is non-null only if it was configured there.
        if (m_replaceMissing != null) {
          m_replaceMissing.input(instance);
          instance = m_replaceMissing.output();
        }

        if (m_nominalToBinary != null) {
          m_nominalToBinary.input(instance);
          instance = m_nominalToBinary.output();
        }

        if (m_normalize != null) {
          m_normalize.input(instance);
          instance = m_normalize.output();
        }
      }

      double wx = dotProd(instance, m_weights, instance.classIndex());

      double y;
      double z;
      if (instance.classAttribute().isNominal()) {
        // Classification: map class {0,1} to y in {-1,+1}; z is the margin.
        y = (instance.classValue() == 0) ? -1 : 1;
        z = y * (wx + m_weights[m_weights.length - 1]);
      } else {
        // Regression: z is the residual; y is reset to 1 so the update's
        // sign comes entirely from dloss(z).
        y = instance.classValue();
        z = y - (wx + m_weights[m_weights.length - 1]);
        y = 1;
      }

      // Compute multiplier for weight decay (L2 shrinkage). In pure
      // incremental mode (m_numInstances == 0) the decay is scaled by the
      // iteration count m_t instead of the training-set size.
      double multiplier = 1.0;
      if (m_numInstances == 0) {
        multiplier = 1.0 - (m_learningRate * m_lambda) / m_t;
      } else {
        multiplier = 1.0 - (m_learningRate * m_lambda) / m_numInstances;
      }
      for (int i = 0; i < m_weights.length - 1; i++) {
        m_weights[i] *= multiplier;
      }

      // Only need to do the following if the loss (sub)gradient is non-zero:
      // hinge contributes nothing once the margin reaches 1, and
      // epsilon-insensitive loss nothing inside the epsilon tube.
      if (m_loss == SQUAREDLOSS || m_loss == LOGLOSS || m_loss == HUBER
          || (m_loss == HINGE && (z < 1))
          || (m_loss == EPSILON_INSENSITIVE && Math.abs(z) > m_epsilon)) {

        // Compute Factor for updates
        double factor = m_learningRate * y * dloss(z);

        // Update coefficients for attributes (sparse iteration; the class
        // attribute and missing values are skipped)
        int n1 = instance.numValues();
        for (int p1 = 0; p1 < n1; p1++) {
          int indS = instance.index(p1);
          if (indS != instance.classIndex() && !instance.isMissingSparse(p1)) {
            m_weights[indS] += factor * instance.valueSparse(p1);
          }
        }

        // update the bias (stored in the last weight slot)
        m_weights[m_weights.length - 1] += factor;
      }
      m_t++;
    }
  }

  /**
   * Updates the classifier incrementally with the given instance.
   * 
   * @param instance the new training instance to include in the model
   * @exception Exception if the instance could not be incorporated in the
   *              model.
   */
  @Override
  public void updateClassifier(Instance instance) throws Exception {
    // Incremental instances arrive unfiltered, so request filtering.
    updateClassifier(instance, true);
  }

  /**
   * Computes the class distribution (nominal class) or the predicted value
   * (numeric class) for a given instance.
   * 
   * @param inst the instance for which the distribution is computed
   * @return the distribution (two elements for a nominal class, one for a
   *         numeric class)
   * @throws Exception if the distribution can't be computed successfully
   */
  @Override
  public double[] distributionForInstance(Instance inst) throws Exception {
    boolean nominalClass = inst.classAttribute().isNominal();
    double[] dist = new double[nominalClass ? 2 : 1];

    // Push the instance through the same filter chain used during training.
    if (m_replaceMissing != null) {
      m_replaceMissing.input(inst);
      inst = m_replaceMissing.output();
    }
    if (m_nominalToBinary != null) {
      m_nominalToBinary.input(inst);
      inst = m_nominalToBinary.output();
    }
    if (m_normalize != null) {
      m_normalize.input(inst);
      inst = m_normalize.output();
    }

    double z = dotProd(inst, m_weights, inst.classIndex())
        + m_weights[m_weights.length - 1];

    if (!nominalClass) {
      // Regression: the raw linear output is the prediction.
      dist[0] = z;
      return dist;
    }

    if (m_loss == LOGLOSS) {
      // Logistic link yields proper probabilities; branch on the sign of z
      // so Math.exp never receives a large positive argument.
      if (z <= 0) {
        dist[0] = 1.0 / (1.0 + Math.exp(z));
        dist[1] = 1.0 - dist[0];
      } else {
        dist[1] = 1.0 / (1.0 + Math.exp(-z));
        dist[0] = 1.0 - dist[1];
      }
    } else {
      // Hinge loss: hard 0/1 decision from the sign of z (ties go to class 0).
      if (z <= 0) {
        dist[0] = 1;
      } else {
        dist[1] = 1;
      }
    }
    return dist;
  }

  /**
   * Returns the current weight vector. The bias term is stored in the last
   * element. Note: this is the internal array, not a copy.
   * 
   * @return the weight vector, or null if no model has been built yet
   */
  public double[] getWeights() {
    return m_weights;
  }

  /**
   * Prints out the classifier: the loss function in use followed by the
   * learned linear model.
   * 
   * @return a description of the classifier as a string
   */
  @Override
  public String toString() {
    if (m_weights == null) {
      return "SGD: No model built yet.\n";
    }
    StringBuffer buff = new StringBuffer();
    buff.append("Loss function: ");
    // Label all five losses; previously EPSILON_INSENSITIVE and HUBER fell
    // through to the squared-loss label.
    switch (m_loss) {
    case HINGE:
      buff.append("Hinge loss (SVM)\n\n");
      break;
    case LOGLOSS:
      buff.append("Log loss (logistic regression)\n\n");
      break;
    case EPSILON_INSENSITIVE:
      buff.append("Epsilon-insensitive loss (SVM regression)\n\n");
      break;
    case HUBER:
      buff.append("Huber loss (robust regression)\n\n");
      break;
    default:
      buff.append("Squared loss (linear regression)\n\n");
      break;
    }

    buff.append(m_data.classAttribute().name() + " = \n\n");
    int printed = 0;

    for (int i = 0; i < m_weights.length - 1; i++) {
      if (i != m_data.classIndex()) {
        if (printed > 0) {
          buff.append(" + ");
        } else {
          buff.append("   ");
        }

        buff.append(Utils.doubleToString(m_weights[i], 12, 4) + " "
            + ((m_normalize != null) ? "(normalized) " : "")
            + m_data.attribute(i).name() + "\n");

        printed++;
      }
    }

    // Print the bias with an explicit sign.
    if (m_weights[m_weights.length - 1] > 0) {
      buff.append(" + "
          + Utils.doubleToString(m_weights[m_weights.length - 1], 12, 4));
    } else {
      buff.append(" - "
          + Utils.doubleToString(-m_weights[m_weights.length - 1], 12, 4));
    }

    return buff.toString();
  }

  /**
   * Returns the revision string.
   * 
   * @return the revision
   */
  @Override
  public String getRevision() {
    final String rev = "$Revision: 9785 $";
    return RevisionUtils.extract(rev);
  }

  /** The number of other models aggregated into this one so far */
  protected int m_numModels = 0;

  /**
   * Aggregate an object with this one by summing the weight vectors; call
   * finalizeAggregation() to average once all models have been added.
   * 
   * @param toAggregate the object to aggregate
   * @return the result of aggregation
   * @throws Exception if the supplied object can't be aggregated for some
   *           reason
   */
  @Override
  public SGD aggregate(SGD toAggregate) throws Exception {

    if (m_weights == null) {
      throw new Exception("No model built yet, can't aggregate");
    }

    // Guard against an unbuilt model on the other side too; previously this
    // surfaced as an uninformative NullPointerException below.
    if (toAggregate.getWeights() == null) {
      throw new Exception("Can't aggregate - SGD to aggregate has not "
          + "been built yet");
    }

    if (!m_data.equalHeaders(toAggregate.m_data)) {
      throw new Exception("Can't aggregate - data headers don't match: "
          + m_data.equalHeadersMsg(toAggregate.m_data));
    }

    if (m_weights.length != toAggregate.getWeights().length) {
      // Fixed typo in the message: "SDG" -> "SGD".
      throw new Exception(
          "Can't aggregate - SGD to aggregate has weight vector "
              + "that differs in length from ours.");
    }

    for (int i = 0; i < m_weights.length; i++) {
      m_weights[i] += toAggregate.getWeights()[i];
    }

    m_numModels++;

    return this;
  }

  /**
   * Call to complete the aggregation process: averages the summed weight
   * vectors over all aggregated models, then resets the aggregation counter.
   * 
   * @throws Exception if no models have been aggregated yet
   */
  @Override
  public void finalizeAggregation() throws Exception {    
    if (m_numModels == 0) {
      throw new Exception("Unable to finalize aggregation - " +
                "haven't seen any models to aggregate");
    }

    final int divisor = m_numModels + 1; // plus one for our own model
    for (int i = 0; i < m_weights.length; i++) {
      m_weights[i] /= divisor;
    }

    // aggregation complete
    m_numModels = 0;
  }
  
  /**
   * Main method for testing this class from the command line.
   * 
   * @param args command-line options (see listOptions())
   */
  public static void main(String[] args) {
    runClassifier(new SGD(), args);
  }
}
